import torch
import torch.utils.data
import numpy as np
from torch import nn, optim

from vae_models.fc_model import VAE
from learned_sigma import get_distribution_class
from distributions import AttrDict


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single axis."""

    def forward(self, input):
        """Return ``input`` viewed as a 2-D tensor of shape (input.size(0), -1)."""
        leading = input.shape[0]
        return input.view(leading, -1)


class UnFlatten(nn.Module):
    """Reshape a flat (batch, features) tensor into square feature maps.

    The feature axis is split into ``n_channels`` maps; the side length of
    each map is the integer square root of the per-channel feature count.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels

    def forward(self, input):
        """Return ``input`` viewed as (batch, n_channels, side, side)."""
        batch = input.size(0)
        per_channel = input.size(1) // self.n_channels
        side = int(per_channel ** 0.5)
        return input.view(batch, self.n_channels, side, side)



class VAE_Conv(VAE):
    """VAE variant with a convolutional encoder and a transposed-conv decoder.

    The encoder applies one stride-1 and two stride-2 convolutions and then
    flattens; the decoder un-flattens and applies two stride-2 and one
    stride-1 transposed convolutions. With the 28x28 probe used in
    ``build_network`` the spatial size goes 28 -> 14 -> 7 and back.

    Convolution arithmetic reference:
    https://github.com/vdumoulin/conv_arithmetic
    """

    @staticmethod
    def get_encoder(img_channels, filters_m):
        """Build the convolutional encoder.

        Args:
            img_channels: number of channels of the input image.
            filters_m: filter count of the first convolution; the following
                layers use ``2 * filters_m`` and ``4 * filters_m``.

        Returns:
            An ``nn.Sequential`` whose output is flattened to (batch, features).
        """
        encoder = nn.Sequential(
            nn.Conv2d(img_channels, filters_m, (3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(filters_m, 2 * filters_m, (4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * filters_m, 4 * filters_m, (5, 5), stride=2, padding=2),
            nn.ReLU(),
            Flatten()
        )
        return encoder

    @staticmethod
    def get_decoder(filters_m, out_channels, activation=None):
        """Build the transposed-convolution decoder.

        Args:
            filters_m: base filter count; must match the value given to
                ``get_encoder`` so the flat input splits into
                ``4 * filters_m`` feature maps.
            out_channels: number of channels produced by the last layer.
            activation: optional module attached as ``final_activation``.

        Returns:
            An ``nn.Sequential`` mapping (batch, features) to an image tensor.
        """
        decoder = nn.Sequential(
            UnFlatten(4 * filters_m),
            nn.ConvTranspose2d(4 * filters_m, 2 * filters_m, (6, 6), stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * filters_m, filters_m, (6, 6), stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(filters_m, out_channels, (5, 5), stride=1, padding=2),
        )
        if activation is not None:
            # Assigning an nn.Module as an attribute registers it as a
            # submodule, and nn.Sequential runs its registered submodules in
            # order, so the activation ends up applied after the last layer.
            # NOTE(review): a plain callable (not an nn.Module) would only be
            # stored, not applied -- confirm what get_distribution_class returns.
            decoder.final_activation = activation

        return decoder

    def build_network(self):
        """Create the encoder, the latent heads and the decoder(s).

        Reads ``self.args.distribution``, ``self.args.sigma_mode``,
        ``self.args.n_filters``, ``self.img_channels`` and ``self.z_dim``.
        When ``sigma_mode == 'image'`` a second decoder,
        ``self.image_sigma_decoder``, is created without a final activation.
        """
        # Spatial size of the probe image used to infer the flattened size.
        img_size = 28
        _, self.distr_params, activation = get_distribution_class(self.args.distribution, self.args.sigma_mode)
        filters_m = self.args.n_filters

        self.encoder = self.get_encoder(self.img_channels, filters_m)

        ## output size depends on input image size
        demo_input = torch.ones([1, self.img_channels, img_size, img_size])
        # Only the output shape is needed, so skip building an autograd graph.
        with torch.no_grad():
            h_dim = self.encoder(demo_input).shape[1]
        print('h_dim', h_dim)
        ## map to latent z
        # Two linear heads over the same features; `encode` returns both.
        self.fc11 = nn.Linear(h_dim, self.z_dim)
        self.fc12 = nn.Linear(h_dim, self.z_dim)

        ## decoder
        self.fc2 = nn.Linear(self.z_dim, h_dim)
        self.decoder = self.get_decoder(filters_m, self.img_channels * self.distr_params, activation)

        if self.args.sigma_mode == 'image':
            self.image_sigma_decoder = self.get_decoder(filters_m, self.img_channels * self.distr_params)

    def encode(self, x):
        """Encode ``x`` and return the outputs of both latent heads.

        Returns:
            Tuple ``(fc11(h), fc12(h))`` -- presumably the mean and
            log-variance of the approximate posterior; confirm against the
            base class.
        """
        h = self.encoder(x)
        return self.fc11(h), self.fc12(h)

    def decode_mean(self, z):
        """Decode latent ``z`` into the parameters of the output distribution.

        Returns:
            An ``AttrDict`` whose fields depend on ``self.args.distribution``:
            * contains 'categorical': ``log_prob`` and ``mle``;
            * 'bernoulli': ``mu`` and ``mle`` (both the raw decoder output);
            * otherwise: ``mu``, ``mle`` (``d.mean``) and ``d``; in the
              default (non-beta) branch ``sigma`` is added as well when
              ``sigma_mode == 'image'``.
        """
        out = self.decoder(self.fc2(z))
        output = AttrDict()

        if 'categorical' in self.args.distribution:
            # Split the channel axis: (B, P * C, ...) -> (B, P, C, ...),
            # with P = self.distr_params.
            out = out.reshape(out.shape[0:1] + (self.distr_params, -1) + out.shape[2:])
            mean = self.distr(log_p=out).mean / 2 + 0.5
            return AttrDict(log_prob=out, mle=mean)
        elif self.args.distribution == 'bernoulli':
            return AttrDict(mu=out, mle=out)
        elif self.args.distribution == 'beta':
            # First img_channels channels form one parameter, the rest the other.
            d = self.distr(out[:, :self.img_channels], out[:, self.img_channels:])
        else:
            d = self.distr(out, 0)

            if self.args.sigma_mode == 'image':
                if self.args.detach_sigma_network:
                    # Stop gradients of the sigma branch at z; fc2 and the
                    # sigma decoder themselves still receive gradients.
                    z = z.detach()
                output.sigma = self.image_sigma_decoder(self.fc2(z))

        # Note 'mu' actually means a parameter that is fed into distr, and mle actually means the visualized parameter
        # In practice, mle is always the mean.
        output.mu = out
        output.mle = d.mean
        output.d = d
        return output
